from flask import Flask
from flask import request
from flask import jsonify
from flask import send_from_directory
import random
import sys
import threading
import os
import shutil
from importlib.resources import files
import soundfile as sf
import tqdm
from cached_path import cached_path
from hydra.utils import get_class
from omegaconf import OmegaConf

from f5_tts.infer.utils_infer import (
    infer_process,
    load_model,
    load_vocoder,
    preprocess_ref_audio_text,
    remove_silence_for_generated_wav,
    save_spectrogram,
    transcribe,
)
from f5_tts.model.utils import seed_everything


class F5TTS:
    """Load a TTS model and vocoder once and expose inference/export helpers.

    The instance keeps the loaded model (``ema_model``), the vocoder, the
    selected device and the mel-spectrogram settings read from the model's
    YAML config so that repeated calls to :meth:`infer` reuse them.
    """

    def __init__(
        self,
        model="F5TTS_v1_Base",
        ckpt_file="",
        vocab_file="",
        ode_method="euler",
        use_ema=True,
        vocoder_local_path=None,
        device=None,
        hf_cache_dir=None,
    ):
        """Read the model config, pick a device and load vocoder and model.

        Args:
            model: Name of the config to load; resolved to
                ``configs/{model}.yaml`` inside the ``f5_tts`` package.
            ckpt_file: Path to a checkpoint. When empty, a checkpoint is
                fetched through ``cached_path`` from the ``hf://SWivid/...``
                location derived from ``model``.
            vocab_file: Forwarded to ``load_model``.
            ode_method: Stored on the instance and forwarded to ``load_model``.
            use_ema: Stored on the instance and forwarded to ``load_model``.
            vocoder_local_path: Forwarded to ``load_vocoder``; whether it is
                ``None`` is also passed along as the "local" flag.
            device: Device string to use. When ``None`` the first available
                of ``"cuda"``, ``"xpu"``, ``"mps"`` is chosen, else ``"cpu"``.
            hf_cache_dir: Cache directory handed to ``load_vocoder`` and
                ``cached_path``.
        """
        model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
        model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
        model_arc = model_cfg.model.arch

        self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
        self.target_sample_rate = model_cfg.model.mel_spec.target_sample_rate

        self.ode_method = ode_method
        self.use_ema = use_ema

        if device is not None:
            self.device = device
        else:
            # Imported lazily: torch is only needed here when auto-detecting.
            import torch

            # Checks run in priority order and stop at the first hit.
            if torch.cuda.is_available():
                self.device = "cuda"
            elif torch.xpu.is_available():
                self.device = "xpu"
            elif torch.backends.mps.is_available():
                self.device = "mps"
            else:
                self.device = "cpu"

        # Load models
        self.vocoder = load_vocoder(
            self.mel_spec_type, vocoder_local_path is not None, vocoder_local_path, self.device, hf_cache_dir
        )

        # Defaults for the checkpoint location; older models override below.
        repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"

        # override for previous models
        if model == "F5TTS_Base":
            if self.mel_spec_type == "vocos":
                ckpt_step = 1200000
            elif self.mel_spec_type == "bigvgan":
                model = "F5TTS_Base_bigvgan"
                ckpt_type = "pt"
        elif model == "E2TTS_Base":
            repo_name = "E2-TTS"
            ckpt_step = 1200000

        if not ckpt_file:
            ckpt_file = str(
                cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}", cache_dir=hf_cache_dir)
            )
        self.ema_model = load_model(
            model_cls, model_arc, ckpt_file, self.mel_spec_type, vocab_file, self.ode_method, self.use_ema, self.device
        )

    def transcribe(self, ref_audio, language=None):
        """Return the result of the module-level ``transcribe`` helper."""
        return transcribe(ref_audio, language)

    def export_wav(self, wav, file_wave, remove_silence=False):
        """Write ``wav`` to ``file_wave`` at the model's target sample rate.

        When ``remove_silence`` is true, ``remove_silence_for_generated_wav``
        is run on the written file afterwards.
        """
        sf.write(file_wave, wav, self.target_sample_rate)

        if remove_silence:
            remove_silence_for_generated_wav(file_wave)

    def export_spectrogram(self, spec, file_spec):
        """Save ``spec`` to ``file_spec`` via ``save_spectrogram``."""
        save_spectrogram(spec, file_spec)

    def infer(
        self,
        ref_file,
        ref_text,
        gen_text,
        show_info=print,
        progress=tqdm,
        target_rms=0.1,
        cross_fade_duration=0.15,
        sway_sampling_coef=-1,
        cfg_strength=2,
        nfe_step=32,
        speed=1.0,
        fix_duration=None,
        remove_silence=False,
        file_wave=None,
        file_spec=None,
        seed=None,
    ):
        """Generate speech for ``gen_text`` using a reference audio/text pair.

        Args:
            ref_file: Reference audio, passed through
                ``preprocess_ref_audio_text`` before inference.
            ref_text: Reference text, preprocessed together with ``ref_file``.
            gen_text: Text to synthesize.
            show_info, progress, target_rms, cross_fade_duration,
            sway_sampling_coef, cfg_strength, nfe_step, speed, fix_duration:
                Forwarded unchanged to ``infer_process``.
            remove_silence: Passed to :meth:`export_wav` when ``file_wave``
                is given.
            file_wave: If not ``None``, the generated audio is written here.
            file_spec: If not ``None``, the spectrogram is saved here.
            seed: Seed for ``seed_everything``. When ``None`` a random value
                in ``[0, sys.maxsize]`` is drawn. The seed actually used is
                stored on ``self.seed``.

        Returns:
            The ``(wav, sr, spec)`` tuple produced by ``infer_process``.
        """
        if seed is None:
            seed = random.randint(0, sys.maxsize)

        seed_everything(seed)

        # Remember the seed so a caller can reproduce this generation.
        self.seed = seed

        ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text)

        wav, sr, spec = infer_process(
            ref_file,
            ref_text,
            gen_text,
            self.ema_model,
            self.vocoder,
            self.mel_spec_type,
            show_info=show_info,
            progress=progress,
            target_rms=target_rms,
            cross_fade_duration=cross_fade_duration,
            nfe_step=nfe_step,
            cfg_strength=cfg_strength,
            sway_sampling_coef=sway_sampling_coef,
            speed=speed,
            fix_duration=fix_duration,
            device=self.device,
        )

        # Honour the optional output paths; previously these arguments were
        # accepted but silently ignored.
        if file_wave is not None:
            self.export_wav(wav, file_wave, remove_silence)

        if file_spec is not None:
            self.export_spectrogram(spec, file_spec)

        return wav, sr, spec


# Flask application object; the route handler below is registered on it.
app = Flask(__name__)

# Held by the request handler while it runs inference, so concurrent
# requests use the shared model one at a time.
lock = threading.Lock()

@app.route('/generate_audio', methods=['POST'])
def generate_audio():
    """Synthesize speech from POSTed form fields and return it as WAV.

    Form fields:
        ref_file: Reference audio, handed to ``f5tts.infer`` as-is.
        ref_text: Reference text for the reference audio.
        gen_text: Text to synthesize.
        speed: Optional positive number forwarded to ``f5tts.infer``;
            defaults to 1.0 when omitted.

    Returns:
        A WAV response built from an in-memory buffer, or a JSON error with
        status 400 when ``speed`` is not a positive, finite number. Missing
        required fields produce Flask's standard 400 response.
    """
    # Local imports so this handler does not depend on edits to the
    # file-level import block.
    import io

    from flask import send_file

    # NOTE(review): ref_file comes straight from the client and is passed to
    # inference unchanged; expose this endpoint to trusted clients only.
    ref_file = request.form['ref_file']
    ref_text = request.form['ref_text']
    gen_text = request.form['gen_text']

    try:
        speed = float(request.form.get('speed', 1.0))
    except ValueError:
        return jsonify(error="'speed' must be a number"), 400
    # The chained comparison is False for NaN, zero, negatives and infinity.
    if not 0 < speed < float('inf'):
        return jsonify(error="'speed' must be a positive, finite number"), 400

    with lock:
        wav, _, _ = f5tts.infer(
            ref_file=ref_file,
            ref_text=ref_text,
            gen_text=gen_text,
            speed=speed,
        )

    # Encode into a per-request buffer instead of a shared 'output.wav':
    # a shared file could be overwritten by another request while it is being
    # sent, and the working directory it was written to need not match the
    # directory Flask serves from.
    buffer = io.BytesIO()
    sf.write(buffer, wav, f5tts.target_sample_rate, format='WAV')
    buffer.seek(0)

    return send_file(buffer, mimetype='audio/wav')

if __name__ == "__main__":
    # An optional first command-line argument is used as the checkpoint file;
    # without it the class defaults apply. ``f5tts`` is the module-level
    # instance the request handler reads.
    f5tts = F5TTS(ckpt_file=sys.argv[1]) if len(sys.argv) > 1 else F5TTS()

    # Start Flask's built-in server with its default settings.
    app.run()
